import json
import math

import torch
from torchvision import transforms as T, utils



from Image_Mediator_Training.imageMediator_graph import set_imageMediator
from ModularUtils.ControllerConstants import get_multiple_labels_fill
from ModularUtils.ControllerModel import get_generators, get_generated_labels
from ModularUtils.Experiment_Class import Experiment


import numpy as np
import matplotlib.pyplot as plt

from ModularUtils.Functions_Plot_Results import concat_horizon, concat_vertical

# Experiment configuration; it must match the settings the loaded runs were
# trained with. new_experiment=False is passed so no new run is created here
# (presumed from the flag name -- confirm against Experiment).
Exp = Experiment(
    "Exp1", set_imageMediator,
    Temperature=1,
    temp_min=0.1,
    G_hid_dims=[256, 256],
    D_hid_dims=[256, 256],
    IMAGE_FILTERS=[128, 64, 32],
    CRITIC_ITERATIONS=5,
    LAMBDA_GP=10,
    learning_rate=5 * 1e-4,
    Synthetic_Sample_Size=20000,
    intv_Sample_Size=20000,
    batch_size=200,
    features=["feature"],
    noise_states=64,
    latent_state=4,
    Data_intervs=[{}],
    num_epochs=300,
    new_experiment=False,
)

# Interventional batches use the same size as observational batches.
Exp.intv_batch_size = Exp.batch_size


# Per-DAG metadata written alongside the saved experiments.
SHARED_INFO = "/path_to_project/SAVED_EXPERIMENTS/" + Exp.Complete_DAG_desc + "/SHARED_INFO.txt"
# JSON text is UTF-8 (RFC 8259); be explicit instead of relying on the locale.
with open(SHARED_INFO, encoding="utf-8") as f:
    INSTANCE = json.load(f)

# Which generator modules to restore from each saved run.
Exp.load_which_models = {"medD": True, "I": True, "RI": True, "medC": True}


def _load_eval_generators(exp, model_path):
    """Load the generators saved under *model_path* and switch them to eval mode.

    Sets ``exp.LOAD_MODEL_PATH`` to *model_path* (``get_generators`` is given
    only ``exp``, so the path is presumably read from that attribute), then
    calls ``get_generators(exp, exp.load_which_models)``, discards its second
    return value, calls ``.eval()`` on every generator and returns the
    generator mapping.
    """
    exp.LOAD_MODEL_PATH = model_path
    generators, _ = get_generators(exp, exp.load_which_models)
    for generator in generators.values():
        generator.eval()
    return generators


# Directory holding the timestamped training runs being compared.
_RUNS_DIR = "/path_to_project/SAVED_EXPERIMENTS/imageMediator/Exp1/"

# NCM baseline run (labelled "NCM" in the plot).
last_exp = _RUNS_DIR + "May_01_2023-04_42"
generators_full = _load_eval_generators(Exp, last_exp)

# Full-representation run (labelled "Rep" in the plot).
last_exp = _RUNS_DIR + "May_10_2023-08_54_fullrep"
generators_rep = _load_eval_generators(Exp, last_exp)

# Modular, parallel-trained run (labelled "Ours" in the plot). Loaded last,
# so Exp.LOAD_MODEL_PATH is left pointing at this run, as before.
last_exp = _RUNS_DIR + "May_10_2023-17_49_modular_parallel"
generators_mod = _load_eval_generators(Exp, last_exp)


# Sample images from the three loaded models under the same intervention on
# medD and show them as one grid, one row per model.
with torch.no_grad():

    minibatch = 6
    # The first half of the batch gets do(medD=0) and the rest do(medD=1).
    # Floor division plus "the remainder" keeps exactly `minibatch` rows even
    # when minibatch is odd.
    half = minibatch // 2
    compare_Var = ['medD']
    zeros = torch.zeros((half, 1))
    ones = torch.ones((minibatch - half, 1))
    medD = torch.cat([zeros, ones], 0)
    medD = get_multiple_labels_fill(Exp, medD, [2], isImage_labels=False)

    image_label = Exp.image_labels[0]

    # Same query against each model, in the original order (full, rep, mod).
    # A fresh variable list is built per call, as before.
    sampled_images = []
    for generators in (generators_full, generators_rep, generators_mod):
        generated_labels_dict = get_generated_labels(
            Exp, generators, {}, {}, {'medD': medD},
            compare_Var + [image_label], minibatch, hard=True,
        )
        sampled_images.append(generated_labels_dict[image_label])
    full_image, rep_image, mod_image = sampled_images

    # Build the grid.
    # - Each model's samples are placed side by side in one row.
    # - Samples are separated by 1-px white gaps.
    # - A 2-px gap closes each intervention group.
    # - Rows are stacked with a 1-px white separator below each.
    output_col = None
    for row_idx, img_array in enumerate([full_image, rep_image, mod_image]):
        # Author's note on row 0 (full/NCM model): it learned P(Image) but no
        # condition from the joint, so its samples do not match the labels.
        n_imgs = len(img_array)
        output_row = None
        for i, img in enumerate(img_array):
            # permute(1, 2, 0) moves the first axis last; presumably (C, H, W)
            # -> (H, W, C) for imshow -- confirm generator output layout.
            frame = img.cpu().permute(1, 2, 0).numpy()
            output_row = frame if i == 0 else concat_horizon(output_row, frame)

            h, _, c = frame.shape
            # Wider gap after the last sample of each intervention group.
            gap = 2 if (i + 1) in (half, n_imgs) else 1
            output_row = np.concatenate((output_row, np.ones((h, gap, c))), axis=1)

        output_col = output_row if row_idx == 0 else concat_vertical(output_col, output_row)

        _, w, c = output_row.shape
        output_col = np.concatenate((output_col, np.ones((1, w, c))), axis=0)

    font1 = {'size': 16}

    plt.title("Colored-MNIST digits for frontdoor", fontdict=font1)
    # Column ranges are derived from the intervention split. For minibatch=6
    # this is exactly "[1-3]: do(D=0), [4-6]: do(D=1)".
    plt.xlabel(
        "(From Top) Row 1: NCM, Row 2: Rep,  Row 3: Ours. \n (From Left) "
        f"Columns [1-{half}]: do(D=0), [{half + 1}-{minibatch}]: do(D=1)",
        fontdict=font1,
    )
    plt.ylabel("Sampled images")

    plt.imshow(output_col)
    plt.show()
